Trie树

一.引入

如果有一道题是这样的:

给你n个字符串组成的一个词典,m次询问,每次给出一个字符串,问它是否出现在词典中。

MapMap

那么,我们改一下询问,每次给出一个字符串,问以它为前缀的词有多少个。

现在,用MapMap储存每个字符的前缀肯定会MLE+TLEMLE+TLE,我们需要引入一种新的数据结构——TrieTrie树。

二.字典树的概念

字典树,又称单词查找树(Trie树),是一种树形结构,一种哈希树的变种。典型应用是用于统计,排序和保存大量的字符串(但不仅限于字符串),所以经常被搜索引擎系统用于文本词频统计。

它的优点是:利用字符串的公共前缀来减少查询时间,最大限度地减少无谓的字符串比较,查询效率比哈希树高

三.字典树的性质

  1. 根节点不包含字符,除根节点外每一个节点都只包含一个字符。

  2. 从根节点到某一节点,路径上经过的字符连接起来,为该节点对应的字符串。

  3. 每个节点的所有子节点包含的字符都不相同。

四.字典树的操作

我们用Trie[i][j]Trie[i][j]表示节点ii的子节点其中一个字母(数字)的ASCLLASCLL码值为 j ~j~的那个节点的编号(是不是有点绕),为了方便我们计算jj时减去 'a' ( '0' )。为了不RERE尽量把ii开大。

fin[i]fin[i]表示TrieTrie树中是否有一个单词以节点 i ~i~结束。cntcnt表示TrieTrie树中的节点个数。

1.建树操作(BuildBuild

没有什么特殊操作,多组输入时清空上述数组、变量。

void Build( ) {
	memset( Trie , 0 , sizeof( Trie ) );
	memset( fin , 0 , sizeof( fin ) );
	cnt = 0;
}

2.插入操作(InsertInsert

现在我们要在字典树中插入一个字符串 , 用uu记录当前到达的节点编号。

最开始时,uu在根节点(0)。然后我们枚举字符串的每一位,如果出现在uu的子节点中,就将uu移到下一个子节点,不然就以uu为父亲新建一个子节点,再将uu移到该子节点。

将字符串插入字典树后,将最后uu所在位置(字符串的末尾)的finfin标记为1。

void Insert( char *str ) {
	int len = strlen( str ) , u = 0;
	for( int i = 0 ; i < len ; i ++ ) {
		int num = str[ i ] - 'a';
		if( !Trie[ u ][ num ] ) 
			Trie[ u ][ num ] = ++ cnt;
		u = Trie[ u ][ num ];
	}
	fin[ u ] = 1;
}

3.查询操作(FindFind

现在我们要查询一个字符串是否出现在在字典树中,同样用uu记录当前到达的节点编号。

最开始时,uu在根节点(0)。然后我们枚举字符串的每一位,如果没有出现在uu的子节点中,返回0。不然将uu移到子节点,如果整个字符串出现在uu下移的路径上,返回1。

需要注意的是,如果查询的字符串是字典树中某个串的前缀,从严格意义上讲它并不算出现在字典树中。所以当字符串遍历完后,如果uu所在位置并不是某字符串的结尾(fin[u]fin[u]=0),还是返回0。

bool Find( char *str ) {
	int len = strlen( str ) , u = 0;
	for( int i = 0 ; i < len ; i ++ ) {
		int num = str[ i ] -'a';
		if( !Trie[ u ][ num ] )
			return 0;
		u = Trie[ u ][ num ];
	}
	return fin[ u ];
}

4.删除操作(DeleteDelete)(不常用)

现在我们要在字典树中删除一个字符串 , 还是用uu记录当前到达的节点编号。

最开始时,uu在根节点(0)。然后我们枚举字符串的每一位,如果出现在uu的子节点中,就将uu移到下一个子节点,同时将uu向子节点的边删去(准确地讲应是将uu与子节点的联系删去)。注意,u的值已经改变,所以要用一个变量来保存uu

将字符串从字典树删除后,将最后uu所在位置(字符串的末尾)的finfin标记为0。

void Delete( char *str ) {
	int len = strlen( str ) , u = 0;
	for( int i = 0 ; i < len ; i ++ ) {
		int num = str[ i ] - 'a';
		if( Trie[ u ][ num ] ) {
			int t = u;
			u = Trie[ u ][ num ];
			Trie[ t ][ num ] = 0;
		}	
	}
	fin[ u ] = 0;
}

5.模板

#include <cstdio>
#include <cstring>

const int MAXN = 10000;
int Trie[ MAXN + 5 ][ 30 ] , cnt;
bool fin[ MAXN + 5 ];

void Build( ) {
	memset( Trie , 0 , sizeof( Trie ) );
	memset( fin , 0 , sizeof( fin ) );
	cnt = 0;
}
void Insert( char *str ) {
	int len = strlen( str ) , u = 0;
	for( int i = 0 ; i < len ; i ++ ) {
		int num = str[ i ] - 'a';
		if( !Trie[ u ][ num ] ) 
			Trie[ u ][ num ] = ++ cnt;
		u = Trie[ u ][ num ];
	}
	fin[ u ] = 1;
}
bool Find( char *str ) {
	int len = strlen( str ) , u = 0;
	for( int i = 0 ; i < len ; i ++ ) {
		int num = str[ i ] -'a';
		if( !Trie[ u ][ num ] )
			return 0;
		u = Trie[ u ][ num ];
	}
	return fin[ u ];
}
void Delete( char *str ) {
	int len = strlen( str ) , u = 0;
	for( int i = 0 ; i < len ; i ++ ) {
		int num = str[ i ] - 'a';
		if( Trie[ u ][ num ] ) {
			int t = u;
			u = Trie[ u ][ num ];
			Trie[ t ][ num ] = 0;
		}	
	}
	fin[ u ] = 0;
}

从以上操作可以看出,TrieTrie树对字符串的操作均为该字符串的长度。
但是,在储存字母串时,TrieTrie树最多为2626叉树。

所以,TrieTrie树是以空间换取时间

五.字典树例题

在实际问题中,我们往往需要修改一下我们的操作,来达到题目的目的。下面讲一下几道例题。

1.P2580 于是他错误的点名开始了

这道题是TrieTrie树的板题,先将nn个名字插入字典树中,对于查询的名字,在字典树中查询即可。

我们可以用fin[u]fin[u]记录uu节点的状态。fin[u]=0fin[u]=0表示字符串未出现过,fin[u]=1fin[u]=1表示第一次查询该字符串,fin[u]=2fin[u]=2表示多次查询该字符串。

#include <cstdio>
#include <cstring>

const int MAXN = 1000000;
int Trie[ MAXN + 5 ][ 30 ] , cnt;
int fin[ MAXN + 5 ];

void Build( ) {
    memset( Trie , 0 , sizeof( Trie ) );
    memset( fin , 0 , sizeof( fin ) );
    cnt = 0;
}
void Insert( char *str ) {
    int len = strlen( str ) , u = 0;
    for( int i = 0 ; i < len ; i ++ ) {
        int num = str[ i ] - 'a';
        if( !Trie[ u ][ num ] ) 
            Trie[ u ][ num ] = ++ cnt;
        u = Trie[ u ][ num ];
    }
    fin[ u ] = 1;
}
int Find( char *str ) {
    int len = strlen( str ) , u = 0;
    for( int i = 0 ; i < len ; i ++ ) {
        int num = str[ i ] -'a';
        if( !Trie[ u ][ num ] )
            return 0;
        u = Trie[ u ][ num ];
    }
    if( fin[ u ] == 1 ) {
    	fin[ u ] = 2;
    	return 1;
	}
	return fin[ u ];
}
void Delete( char *str ) {
    int len = strlen( str ) , u = 0;
    for( int i = 0 ; i < len ; i ++ ) {
        int num = str[ i ] - 'a';
        if( Trie[ u ][ num ] ) {
            int t = u;
            u = Trie[ u ][ num ];
            Trie[ t ][ num ] = 0;
        }   
    }
    fin[ u ] = 0;
}

int n,m;
char s[ 55 ];
int main( ) {
	scanf("%d",&n);
	for( int i = 1 ; i <= n ; i ++ ) {
		scanf("%s",s);
		Insert( s );
	}
	scanf("%d",&m);
	for( int i = 1 ; i <= m ; i ++ ) {
		scanf("%s",s);
		int f = Find( s );
		if( f == 1 )
			printf("OK\n");
		if( f == 2 )
			printf("REPEAT\n");
		if( f == 0 )
			printf("WRONG\n");
	}
	return 0;
}

2.P2922 秘密消息Secret Message

这道题还是TrieTrie树的板题,先将nn个信息插入字典树中,对于查询的密码,在字典树中查询即可。

插入很好办,我们现在来讨论查询操作。我们记s1s_1为任意一个信息,s2s_2是待匹配的密码。tot[u]tot[u]表示有多少个字符串经过点uufin[u]fin[u]表示有多少个字符串以点uu结尾。

那么会出现三种情况:

1.s1s_1s2s_2的前缀,累加路径上所有的fin[u]fin[u]即可。

2.s1=s2s_1=s_2。其实就是情况1,s1s_1结束时的点一定与s2s_2结束时的点相同,也就包含在finfin中。

3.s2s_2s1s_1的前缀,如果单词最后一个字符所在节点为vv,我们很容易看出tot[v]tot[v]就代表该单词是多少字符串的前缀。那么我们只需要加上tot[v]tot[v]就好了。

但是,tot[v]tot[v]会包含以vv节点结尾的单词数,所以结果须减去fin[v]fin[v]

那么,一道蓝题就被我们解决了,附上代码:

#include <cstdio>
#include <cstring>

const int MAXN = 500000;
int n,m,len;
int Trie[ MAXN + 5 ][ 2 ] , cnt;
int tot[ MAXN + 5 ] , fin[ MAXN + 5 ];

void Insert( char *str , int len ) {
    int u = 0;
    for( int i = 0 ; i < len ; i ++ ) {
        int num = str[ i ] - '0';
        if( !Trie[ u ][ num ] ) 
            Trie[ u ][ num ] = ++ cnt;
        u = Trie[ u ][ num ];
        tot[ u ] ++;
    }
    fin[ u ] ++;
}
int Find( char *str , int len ) {
    int u = 0 , Ans = 0;
    for( int i = 0 ; i < len ; i ++ ) {
        int num = str[ i ] -'0';
        if( !Trie[ u ][ num ] )
        	return Ans;
        u = Trie[ u ][ num ];
        Ans += fin[ u ];
    }
    return Ans - fin[ u ] + tot[ u ];
}

int s1;
char s[ 10005 ];
int main( ) {
	scanf("%d %d",&m,&n);
	for( int i = 1 ; i <= m ; i ++ ) {
		scanf("%d",&len);
		for( int j = 0 ; j < len ; j ++ ) {
			scanf("%d",&s1);
			s[ j ] = s1 + '0';
		}
		Insert( s , len );
	}
	for( int i = 1 ; i <= n ; i ++ ) {
		scanf("%d",&len);
		for( int j = 0 ; j < len ; j ++ ) {
			scanf("%d",&s1);
			s[ j ] = s1 + '0';
		}
		printf("%d\n",Find( s , len ));
	}
	return 0;
}